class Generator(nn.Module):
    """Convolutional generator mapping a latent vector to an image.

    A linear layer projects the latent vector to a
    ``(conv_init_dim, init_size, init_size)`` feature map, which is then
    upsampled by a factor of two three times (8x overall) while the channel
    count is reduced from 512 to ``channels``.  A final sigmoid bounds the
    output to the range [0, 1].

    Args:
        img_size: Target spatial size.  The output side length is
            ``8 * (img_size // 8)``, which equals ``img_size`` only when
            ``img_size`` is a multiple of 8.
        latent_dim: Size of the input latent vector.
        channels: Number of channels in the generated image.
    """

    def __init__(self, img_size, latent_dim, channels):
        super().__init__()
        self.init_size = img_size // 8
        self.conv_init_dim = 512
        # Wrapped in a Sequential so the parameters are registered as
        # ``l1.0.weight`` / ``l1.0.bias`` in the state dict.
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, self.conv_init_dim * self.init_size ** 2)
        )
        self.channels = channels

        full = self.conv_init_dim
        half = full // 2
        quarter = full // 4

        def conv_bn_act(in_ch, out_ch):
            """Return a 3x3 same-size conv, batch norm and LeakyReLU(0.2)."""
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                # The second positional argument of BatchNorm2d is ``eps``,
                # so 0.8 is the epsilon here, not the momentum.
                # NOTE(review): possibly intended as momentum; left unchanged
                # because altering it changes the output of existing weights.
                nn.BatchNorm2d(out_ch, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Layer order (and therefore Sequential indices) must stay fixed so
        # that state-dict keys keep matching.
        layers = [
            nn.BatchNorm2d(full),
            nn.Upsample(scale_factor=2),
            *conv_bn_act(full, full),
            nn.Upsample(scale_factor=2),
            *conv_bn_act(full, half),
            *conv_bn_act(half, half),
            nn.Upsample(scale_factor=2),
            *conv_bn_act(half, half),
            *conv_bn_act(half, quarter),
            nn.Conv2d(quarter, self.channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from a batch of latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, channels, 8 * init_size, 8 * init_size)``
            with values bounded by the final sigmoid.
        """
        projected = self.l1(z)
        batch = projected.size(0)
        # Reshape the flat projection into a spatial feature map.
        feature_map = projected.view(
            batch, self.conv_init_dim, self.init_size, self.init_size
        )
        return self.conv_blocks(feature_map)
# Visualise samples from a series of saved generator checkpoints, feeding the
# same latent batch to each one so the resulting image grids are comparable.

# Generator hyper-parameters; these must match the ones the checkpoints were
# saved with, otherwise load_state_dict() raises on shape mismatch.
img_size = 64
latent_dim = 128
channels = 3
# Number of images generated per checkpoint (4 rows of 10 in the grid below).
b_size = 40

checkpoint_dir = '../input/dcgan-generate-anime-faces-4'
# Numeric suffixes of the checkpoint files
# (presumably the training iteration at save time - confirm).
checkpoint_ids = [337, 1348, 2359, 3370, 4381, 5392,
                  6403, 7414, 8425, 9436, 10110]
paths = ['{}/generator{}.pt'.format(checkpoint_dir, ckpt_id)
         for ckpt_id in checkpoint_ids]

generator = Generator(img_size, latent_dim, channels)

# One fixed batch of standard-normal latent vectors, drawn from NumPy's RNG
# and converted to float32 without a round trip through Python lists.
z = torch.from_numpy(np.random.normal(0, 1, (b_size, latent_dim))).float()

for path in paths:
    # map_location='cpu' lets checkpoints saved from a CUDA device load on a
    # CPU-only machine; the generator and z both live on the CPU here.
    state_dict = torch.load(path, map_location='cpu')
    generator.load_state_dict(state_dict)

    # NOTE: the generator is left in training mode (the nn.Module default), so
    # BatchNorm normalises with the statistics of this batch.
    # no_grad() avoids building an autograd graph that is never used.
    with torch.no_grad():
        imgs = generator(z)
    print(imgs.size())

    plt.figure(figsize=(30, 14))
    grid = utils.make_grid(imgs, nrow=10)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title(path.split('/')[-1])
    plt.show()